###############################################
#### EITC population Audit Disparity Conditional on Noncompliance
###############################################
# import packages and modules
import pandas as pd
import os
import pyodbc
import numpy as np
import sys
from pandas._libs.lib import is_integer
pd.options.display.float_format = '{:.4f}'.format
sys.path.insert(1,'/REDACTED/fairness/code/utilities')
#import cbgCoder
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import statsmodels.api as sm
from timeit import default_timer as timer

import matplotlib.font_manager as fm
import matplotlib

sys.path.insert(1,'/REDACTED/fairness/code/config')
from configureColors import *

# set plot defaults
plt.style.use('/REDACTED/fairness/code/config/fairness.mplstyle')

# Register the bundled font file under the family name 'latex' (placed first in
# the font manager's list) and make it the default family for all figures.
fe = fm.FontEntry(
    fname='/REDACTED/fairness/code/config/fonts/cmunrm.ttf',
    name='latex')
fm.fontManager.ttflist.insert(0, fe)
matplotlib.rcParams['font.family'] = fe.name

# Colour / alpha settings per group. The two cdict* helpers are not defined in
# this file (presumably provided by the configureColors star import).
colorsBNB = cdictBlackNonBlack()
colorsENE = cdictEITCNon()
cb, ab = colorsBNB['Black']['color'], colorsBNB['Black']['alpha']
cnb, anb = colorsBNB['non-Black']['color'], colorsBNB['non-Black']['alpha']
ce, ae = colorsENE['eitc']['color'], colorsENE['eitc']['alpha']
cne, ane = colorsENE['non']['color'], colorsENE['non']['alpha']

def weight_qcut(value, weights, q, **kwargs):
    """Assign each observation to a weighted quantile bin of ``value``.

    Observations are ordered by ``value``; the running sum of ``weights`` in
    that order, divided by the total weight, is then cut at the quantile edges.

    Parameters
    ----------
    value : pandas.Series
        Values that determine the ordering.
    weights : pandas.Series
        Weights, positionally aligned with ``value``.
    q : int or sequence of float
        An integer gives ``q`` equal-weight bins (edges ``linspace(0, 1, q + 1)``);
        anything else is passed to ``pandas.cut`` as the bin edges.
    **kwargs
        Forwarded to ``pandas.cut`` (e.g. ``labels=False``).

    Returns
    -------
    pandas.Series
        The ``pandas.cut`` result, sorted by index.
    """
    edges = np.linspace(0, 1, q + 1) if is_integer(q) else q
    # Positions that sort `value` ascending; weights are re-ordered to match.
    sort_positions = value.argsort()
    running_weight = weights.iloc[sort_positions].cumsum()
    cumulative_share = running_weight / running_weight.iloc[-1]
    return pd.cut(cumulative_share, edges, **kwargs).sort_index()

def w_avg(df, values, weights):
    """Return the weighted mean of column ``values`` using column ``weights``.

    Computed as ``sum(values * weights) / sum(weights)`` over the rows of ``df``.
    """
    weight_col = df[weights]
    weighted_total = df[values].mul(weight_col).sum()
    return weighted_total / weight_col.sum()

## research_audits data
research_audits = pd.read_csv('/REDACTED/fairness/code/rf/data/clean_rf_data_plus_dep_database.csv')
research_audits_old = pd.read_csv('/REDACTED/fairness/code/rf/data/merged_rf_data.csv')
#inflation_dict= {2006: 1.18, 2007: 1.16,2008: 1.11,2009: 1.11,2010: 1.08,2011: 1.06,2012: 1.03,2013: 1.02,2014: 1.00}
#research_audits['chg_in_tax_owed_pv'] = [inflation_dict[x]*y for x, y in zip(research_audits.study_year, research_audits.chg_in_tax_owed)]

# do if reading in clean_rf_data_plus_dep_database.csv and not merged_rf_data.csv
# Flag EITC claimants, keep only the columns used below, keep the first row per
# taxpayer, and harmonise column names with the operational data.
research_audits['eitc_ind'] = np.where(research_audits['eitc_amt'] > 0, 1, 0)
research_audits = research_audits[['taxpayer_id', 'predicted_prob_black', 'eitc_ind', 'base_weight', 'study_year', 'ref_cred_amt_dif_pv']]
research_audits = research_audits.drop_duplicates(subset=['taxpayer_id'], keep='first')
research_audits = research_audits.rename(columns={"study_year": "tax_yr", "ref_cred_amt_dif_pv": "taxchange", "base_weight": "base_wgt"})
research_audits = research_audits[research_audits.predicted_prob_black.notna()]


## research_audits EITC
# Take an explicit copy of the filtered rows: new columns are assigned below,
# and assigning onto a boolean-filtered slice triggers SettingWithCopyWarning.
research_audits = research_audits[(research_audits.eitc_ind == 1) & (research_audits.base_wgt > 0)].copy()

## get bins for research_audits
# Rows with taxchange above `threshold` are split into 10 weighted quantile
# bins; labels=False yields codes 0..9, shifted to 1..10 by the +1. Rows at or
# below the threshold are not passed to weight_qcut, come back as NaN after
# index alignment, and are put in bin 0.
threshold = 1
above_threshold = research_audits['taxchange'] > threshold
research_audits['bin'] = weight_qcut(
    value=research_audits.loc[above_threshold, 'taxchange'],
    weights=research_audits.loc[above_threshold, 'base_wgt'],
    q=10,
    labels=False,
) + 1
research_audits['bin'] = research_audits['bin'].fillna(0)

## research_audits race prob
# Sampling weight apportioned by predicted probability of being Black / non-Black.
research_audits['pb'] = research_audits.predicted_prob_black * research_audits.base_wgt
research_audits['pnb'] = (1 - research_audits.predicted_prob_black) * research_audits.base_wgt

### Op audit data 2014
dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final.csv")

# TODO: duplicates are currently resolved by keeping the first row per
# taxpayer_id (the drop_duplicates default), not randomly; revisit later.
collections_data = pd.read_csv('/REDACTED/collections_data_data.csv')

# The bare len(...) expressions are row-count checks for interactive use; they
# print nothing when the file is run as a script.
len(collections_data)
collections_data = collections_data.drop_duplicates(subset=['taxpayer_id']) # 1,000 people
len(collections_data)

dataBISG_old = dataBISG.copy()

# check dataBISG taxpayer_id vs taxpayer_id_new and type
# Both merge keys are cast to the nullable Int64 dtype so they compare equal.
dataBISG['taxpayer_id_new'] = dataBISG['taxpayer_id_new'].astype('Int64')
collections_data['taxpayer_id'] = collections_data['taxpayer_id'].astype('Int64')
collections_data = collections_data.rename(columns={'taxpayer_id': 'taxpayer_id_new'})
len(dataBISG)
len(collections_data)
dataBISG = dataBISG.merge(collections_data, on='taxpayer_id_new', how='left')


## keep op EITC
# Explicit copy: columns are added to this filtered frame below, and assigning
# onto a boolean-filtered slice triggers SettingWithCopyWarning.
dataBISG = dataBISG[dataBISG.isEIC == 1].copy()
#dataBISG['taxchange'] = 376.5/348.3*dataBISG.aud_res
# compare below with the 2 aud_res versions
dataBISG['taxchange'] = dataBISG.sum_dec_tax_amt

### list of bins
# Bin edges for the operational data: the rounded maximum taxchange within each
# research-audit bin, with the top edge replaced by a large sentinel and a
# large negative sentinel added as the bottom edge.
bin_list = research_audits.groupby('bin')['taxchange'].max().round().unique().tolist()
bin_list[-1] = 99999999999
bin_list = sorted(bin_list + [-9999999999])

### assign bins to op
# does this implicitly remove people who have no taxchange varb? check later
# what about people who have a ref cred change but aren't considered audited in our sample?
# Only rows with aud_no_research_audits == 1 are binned; all other rows get NaN
# through index alignment on assignment.
dataBISG['bin'] = pd.cut(
    dataBISG.loc[dataBISG['aud_no_research_audits'] == 1, 'taxchange'],
    bins=bin_list,
    labels=False,
)
# Interactive sanity checks (no output when run as a script).
dataBISG.bin.value_counts()
dataBISG.bin.notnull().sum()
dataBISG['predicted_prob_nonblack'] = 1 - dataBISG['predicted_prob_black']

def get_est(op = dataBISG):
    """Estimate operational audit rates overall and by race group.

    Parameters
    ----------
    op : pandas.DataFrame
        Operational data with columns ``aud_no_research_audits``,
        ``predicted_prob_black`` and ``predicted_prob_nonblack``.

    Returns
    -------
    tuple
        ``(overall_ar, prob_black_ar, prob_nonblack_ar, lin_black_ar,
        lin_nonblack_ar)``:

        * ``overall_ar`` - unweighted mean of the audit indicator;
        * ``prob_*_ar`` - mean of the audit indicator weighted by the
          predicted probability of belonging to the group;
        * ``lin_*_ar`` - fitted values from an OLS of the audit indicator on
          ``predicted_prob_black``, at 1 (Black) and 0 (non-Black).
    """
    overall_ar = op.aud_no_research_audits.mean()
    prob_black_ar = w_avg(op, 'aud_no_research_audits', 'predicted_prob_black')
    prob_nonblack_ar = w_avg(op, 'aud_no_research_audits', 'predicted_prob_nonblack')
    model = sm.OLS.from_formula('aud_no_research_audits ~ predicted_prob_black', data = op).fit(cov_type = 'HC1')
    # params is label-indexed; use .iloc for positional access, because integer
    # keys with [] on a label-indexed Series are deprecated in pandas 2.x.
    intercept = model.params.iloc[0]
    slope = model.params.iloc[1]
    lin_black_ar = intercept + slope
    lin_nonblack_ar = intercept
    return overall_ar, prob_black_ar, prob_nonblack_ar, lin_black_ar, lin_nonblack_ar

def get_aud_rate(research_audits = research_audits, op = dataBISG, bins = bin_list):
    """Compute the audit rate conditional on noncompliance bin, by group.

    For each group (Black, non-Black, all) and each bin::

        aud_cond_noncomp = (bin's share of audited op mass * group audit rate)
                           / (bin's share of weighted research-audit mass)

    Parameters
    ----------
    research_audits : pandas.DataFrame
        Research-audit rows with columns ``bin``, ``taxchange``, ``base_wgt``,
        ``pb`` and ``pnb``.
    op : pandas.DataFrame
        Operational rows with columns ``bin``, ``aud_no_research_audits``,
        ``base_wgt``, ``predicted_prob_black`` and ``predicted_prob_nonblack``.
    bins : list
        Unused; kept so existing callers keep working. Bin membership is read
        from the ``bin`` column of both inputs.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(black_df, nonblack_df, all_df)``, one row per bin, with columns
        ``bin, taxchange, research_audits, op, mean_op_aud_rate_prob,
        mean_op_aud_rate_lin, audited_prob, audited_lin,
        aud_cond_noncomp_prob, aud_cond_noncomp_lin``.
    """
    op_aud = op[op.aud_no_research_audits == 1]
    research_audits_df = research_audits.groupby('bin').agg(
        {'taxchange': 'max', 'base_wgt': 'sum', 'pb': 'sum', 'pnb': 'sum'})
    op_df = op_aud.groupby('bin').agg(
        {'base_wgt': 'sum', 'predicted_prob_black': 'sum', 'predicted_prob_nonblack': 'sum'})
    # Round in place. Rebuilding the column from .unique() would fail with a
    # length mismatch whenever two bins share the same rounded maximum.
    research_audits_df['taxchange'] = research_audits_df['taxchange'].round()
    research_audits_df.columns = ['taxchange', 'tot_research_audits', 'pb_research_audits', 'pnb_research_audits']
    op_df.columns = ['tot_op', 'pb_op', 'pnb_op']

    # Combine on the bin label rather than on row position, so a bin missing
    # from one side (possible in a bootstrap draw) cannot shift the other side's
    # rows. A bin with no audited op rows contributes zero op mass.
    op_cols = ['tot_op', 'pb_op', 'pnb_op']
    by_bin = research_audits_df.join(op_df, how='outer')
    by_bin[op_cols] = by_bin[op_cols].fillna(0)
    by_bin = by_bin.reset_index()

    black_df = by_bin[['bin', 'taxchange', 'pb_research_audits', 'pb_op']].copy()
    nonblack_df = by_bin[['bin', 'taxchange', 'pnb_research_audits', 'pnb_op']].copy()
    all_df = by_bin[['bin', 'taxchange', 'tot_research_audits', 'tot_op']].copy()

    ### GET ESTIMATES
    overall_ar, prob_black_ar, prob_nonblack_ar, lin_black_ar, lin_nonblack_ar = get_est(op = op)

    # (frame, probabilistic-estimator rate, linear-estimator rate); the pooled
    # frame uses the overall audit rate for both estimators.
    group_rates = [
        (black_df, prob_black_ar, lin_black_ar),
        (nonblack_df, prob_nonblack_ar, lin_nonblack_ar),
        (all_df, overall_ar, overall_ar),
    ]
    for df, rate_prob, rate_lin in group_rates:
        df.columns = ['bin', 'taxchange', 'research_audits', 'op']
        df['mean_op_aud_rate_prob'] = rate_prob
        df['mean_op_aud_rate_lin'] = rate_lin
        # Bin's share of the group's weighted research-audit mass.
        df['research_audits'] = df['research_audits'] / df['research_audits'].sum()
        # Bin's share of the group's audited op mass, scaled by the audit rate.
        op_share = df['op'] / df['op'].sum()
        df['audited_prob'] = op_share * df['mean_op_aud_rate_prob']
        df['audited_lin'] = op_share * df['mean_op_aud_rate_lin']
        df['aud_cond_noncomp_prob'] = df['audited_prob'] / df['research_audits']
        df['aud_cond_noncomp_lin'] = df['audited_lin'] / df['research_audits']
    return black_df, nonblack_df, all_df

# Point estimates on the full (un-resampled) research and operational data.
black_df, nonblack_df, all_df = get_aud_rate(
    research_audits=research_audits,
    op=dataBISG,
    bins=bin_list,
)


### Bootstrapping SE

def get_sample(research_audits = research_audits, op = dataBISG, frac = 0.1, random_state = 1):
    """Draw one bootstrap replicate and recompute the by-bin audit-rate tables.

    Both inputs are resampled with replacement to ``frac`` of their size using
    the same ``random_state``. Research-audit rows are drawn with probability
    proportional to ``base_wgt``; operational rows are drawn uniformly.

    Returns the ``(black_df, nonblack_df, all_df)`` tuple from ``get_aud_rate``.
    """
    research_draw = research_audits.sample(
        frac=frac,
        weights=research_audits.base_wgt,
        random_state=random_state,
        replace=True,
    )
    op_draw = op.sample(frac=frac, random_state=random_state, replace=True)
    # NOTE(review): the research draw is weighted by base_wgt and get_aud_rate
    # sums base_wgt again within each bin - confirm this double use is intended.
    return get_aud_rate(research_audits=research_draw, op=op_draw, bins=bin_list)

#b_df,n_df,a_df =  get_sample(research_audits = research_audits, op = dataBISG, frac = 0.1)       


### loop over datasets 100 times, calculating outcomes of interest and resampling data on each iteration
n_iter_bootstrap = 100


def _stack_replicates(rows):
    """Stack a list of per-iteration Series into a frame (one row per iteration)."""
    return pd.DataFrame(rows).reset_index(drop=True)


# DataFrame.append was removed in pandas 2.0 (and copied the whole frame on
# every call), so each replicate's Series is collected in a list and the
# frames are built once after the loop.
b_prob_rows, nb_prob_rows, a_prob_rows = [], [], []
b_lin_rows, nb_lin_rows, a_lin_rows = [], [], []
b_share_rows, nb_share_rows, a_share_rows = [], [], []
for i in range(n_iter_bootstrap):
    start = timer()
    print("Bootstrap iteration: "+str(i))
    # frac = 1: each replicate has as many rows as the original data.
    b_df, n_df, a_df = get_sample(research_audits = research_audits, op = dataBISG, frac = 1, random_state = i)
    b_prob_rows.append(b_df['aud_cond_noncomp_prob'])
    nb_prob_rows.append(n_df['aud_cond_noncomp_prob'])
    a_prob_rows.append(a_df['aud_cond_noncomp_prob'])
    b_lin_rows.append(b_df['aud_cond_noncomp_lin'])
    nb_lin_rows.append(n_df['aud_cond_noncomp_lin'])
    a_lin_rows.append(a_df['aud_cond_noncomp_lin'])
    b_share_rows.append(b_df['research_audits'])
    nb_share_rows.append(n_df['research_audits'])
    a_share_rows.append(a_df['research_audits'])
    end = timer()
    print("Time: "+str(end-start))

# Rows = bootstrap iterations, columns = row position of each bin in the
# tables returned by get_aud_rate.
b_prob = _stack_replicates(b_prob_rows)
nb_prob = _stack_replicates(nb_prob_rows)
a_prob = _stack_replicates(a_prob_rows)
b_lin = _stack_replicates(b_lin_rows)
nb_lin = _stack_replicates(nb_lin_rows)
a_lin = _stack_replicates(a_lin_rows)
b_research_audits_bin_share = _stack_replicates(b_share_rows)
nb_research_audits_bin_share = _stack_replicates(nb_share_rows)
a_research_audits_bin_share = _stack_replicates(a_share_rows)


# Bootstrap standard error = standard deviation across replicates, per bin.
black_df['std_err_prob'] = b_prob.std()
nonblack_df['std_err_prob'] = nb_prob.std()
all_df['std_err_prob'] = a_prob.std()

black_df['std_err_lin'] = b_lin.std()
nonblack_df['std_err_lin'] = nb_lin.std()
all_df['std_err_lin'] = a_lin.std()

black_df['std_err_research_audits'] = b_research_audits_bin_share.std()
nonblack_df['std_err_research_audits'] = nb_research_audits_bin_share.std()
all_df['std_err_research_audits'] = a_research_audits_bin_share.std()

# Multiplier applied to the standard errors when drawing error bars.
n_std_dev = 1.96

# Black minus non-Black gap per bin. Its standard error is taken over the
# within-replicate differences.
diff_df = pd.DataFrame()
diff_df['bin'] = black_df['bin']
diff_df['ar_diff_prob'] = black_df.aud_cond_noncomp_prob - nonblack_df.aud_cond_noncomp_prob
diff_df['ar_diff_lin'] = black_df.aud_cond_noncomp_lin - nonblack_df.aud_cond_noncomp_lin
diff_df['std_err_prob'] = (b_prob - nb_prob).std()
diff_df['std_err_lin'] = (b_lin - nb_lin).std()


black_df.to_csv("/REDACTED/bin_black_df_aud_noncomp_collections_data.csv", index=False)
nonblack_df.to_csv("/REDACTED/bin_nonblack_df_aud_noncomp_collections_data.csv", index=False)
diff_df.to_csv("/REDACTED/bin_diff_df_aud_noncomp_collections_data.csv", index=False)


###############################################
#### Fig 7: Racial Audit Disparity Among EITC Claimants by Underreported Taxes
###############################################

#### PROB
plt.rcParams.update({'errorbar.capsize': 3})
X_axis = np.arange(len(all_df.bin.unique()))

# Tick labels: each bin's maximum taxchange rounded to the nearest 10 and
# formatted as dollars; the last bin is labelled "Max".
lst_bin = all_df.bin.astype(int).astype(str).unique().tolist()
lst_chg = all_df.taxchange.astype(int).unique().tolist()
lst_chg = [f"${round(i, -1):,}" for i in lst_chg]
lst_chg[-1] = "Max"
#X_labels = [i + "\n" + j for i,j in zip(lst_chg, lst_bin)]
X_labels = lst_chg

fig, ax1 = plt.subplots()

plt.xticks(X_axis, X_labels, fontsize = 12)

# Left axis: each group's share of weighted research audits per bin.
ax1.bar(X_axis - 0.1, black_df.research_audits, 0.2, label = 'Black share', color = '#53284f', alpha = 1)
ax1.bar(X_axis + 0.1, nonblack_df.research_audits, 0.2, label = 'Non-Black share', color = '#53284f', alpha = 0.4)
ax1.set_ylabel('Share of group in each \noverclaiming bin')
ax2 = ax1.twinx()

# Right axis: audit rate conditional on bin (probabilistic estimator) with
# bootstrap error bars of n_std_dev standard errors.
# NOTE(review): lines are drawn at the bin values while bars are drawn at
# positions 0..n-1; these coincide only when the bins are labelled 0..n-1.
ax2.plot(black_df.bin, black_df.aud_cond_noncomp_prob, label = 'Black audit rate', color = '#53284f', alpha = 1, linewidth=4)
ax2.plot(nonblack_df.bin, nonblack_df.aud_cond_noncomp_prob, label = 'Non-Black audit rate', color = '#53284f', alpha = 0.4, linewidth=4)
ax2.errorbar(black_df.bin, black_df.aud_cond_noncomp_prob, yerr=n_std_dev*black_df.std_err_prob , color = '#53284f', alpha = 1, elinewidth=4, capsize=4)
ax2.errorbar(nonblack_df.bin, nonblack_df.aud_cond_noncomp_prob, yerr=n_std_dev*nonblack_df.std_err_prob , color = '#53284f', alpha = 0.4, elinewidth=4, capsize=4)
ax2.set_ylabel('Audit Rate')

ax1.set_xlabel("Overclaiming bin")

ax1.set_ylim(0,0.80)
ax2.set_ylim(0,0.16)

h1, l1 = ax1.get_legend_handles_labels()
h2, l2 = ax2.get_legend_handles_labels()
ax2.legend(h1+h2, l1+l2, loc = (0.125, 0.7))

# Format both y axes as whole percentages with a formatter instead of
# set_yticklabels: fixed label strings are not tied to tick positions and can
# end up on the wrong ticks if the locator picks different ticks later.
ax1.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1, decimals=0))
ax2.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1, decimals=0))

# Lay out after the labels are final so their width is accounted for.
plt.tight_layout()


plt.savefig('/REDACTED/aud_noncomp_bin_plot_prob_lw_collections_data.png')
plt.close()



###############################################
#### Fig A.16: Racial Audit Disparity Among EITC Claimants by Underreported Taxes (Linear Estimator)
###############################################

#### LINEAR
plt.rcParams.update({'errorbar.capsize': 3})
X_axis = np.arange(len(all_df.bin.unique()))

# Tick labels: each bin's maximum taxchange rounded to the nearest 10 and
# formatted as dollars; the last bin is labelled "Max".
lst_bin = all_df.bin.astype(int).astype(str).unique().tolist()
lst_chg = all_df.taxchange.astype(int).unique().tolist()
lst_chg = [f"${round(i, -1):,}" for i in lst_chg]
lst_chg[-1] = "Max"
#X_labels = [i + "\n" + j for i,j in zip(lst_chg, lst_bin)]
X_labels = lst_chg

fig, ax1 = plt.subplots()

plt.xticks(X_axis, X_labels, fontsize = 12)


# Left axis: each group's share of weighted research audits per bin.
ax1.bar(X_axis - 0.1, black_df.research_audits, 0.2, label = 'Black share', color = '#53284f', alpha = 1)
ax1.bar(X_axis + 0.1, nonblack_df.research_audits, 0.2, label = 'Non-Black share', color = '#53284f', alpha = 0.4)
ax1.set_ylabel('Share of group in each \noverclaiming bin')
ax2 = ax1.twinx()

# Right axis: audit rate conditional on bin (linear estimator) with bootstrap
# error bars of n_std_dev standard errors.
# NOTE(review): lines are drawn at the bin values while bars are drawn at
# positions 0..n-1; these coincide only when the bins are labelled 0..n-1.
ax2.plot(black_df.bin, black_df.aud_cond_noncomp_lin, label = 'Black audit rate', color = '#53284f', alpha = 1, linewidth=4)
ax2.plot(nonblack_df.bin, nonblack_df.aud_cond_noncomp_lin, label = 'Non-Black audit rate', color = '#53284f', alpha = 0.4, linewidth=4)
#ax2.plot(diff_df.bin, diff_df.ar_diff_lin, label = 'Difference in audit rate', color = 'black')
ax2.errorbar(black_df.bin, black_df.aud_cond_noncomp_lin, yerr=n_std_dev*black_df.std_err_lin , color = '#53284f', alpha = 1, elinewidth=4, capsize=4)
ax2.errorbar(nonblack_df.bin, nonblack_df.aud_cond_noncomp_lin, yerr=n_std_dev*nonblack_df.std_err_lin , color = '#53284f', alpha = 0.4, elinewidth=4, capsize=4)
#ax2.errorbar(diff_df.bin, diff_df.ar_diff_lin, yerr=n_std_dev*diff_df.std_err_lin , color = 'black')
ax2.set_ylabel('Audit Rate')

ax1.set_xlabel("Overclaiming bin")

ax1.set_ylim(0,0.80)
ax2.set_ylim(0,0.16)

h1, l1 = ax1.get_legend_handles_labels()
h2, l2 = ax2.get_legend_handles_labels()
ax2.legend(h1+h2, l1+l2, loc = (0.125, 0.7))


# Format both y axes as whole percentages with a formatter instead of
# set_yticklabels: fixed label strings are not tied to tick positions and can
# end up on the wrong ticks if the locator picks different ticks later.
ax1.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1, decimals=0))
ax2.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1, decimals=0))

# Lay out after the labels are final so their width is accounted for.
plt.tight_layout()


plt.savefig('/REDACTED/aud_noncomp_bin_plot_lin_lw_collections_data.png')
plt.close()